-
Notifications
You must be signed in to change notification settings - Fork 27
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Interpretability module: sparse linear models via LASSO #120
Conversation
|
||
import matplotlib | ||
import matplotlib.pyplot as plt | ||
import mlcolvar.utils.plot |
Check notice
Code scanning / CodeQL
Unused import Note
import mlcolvar.utils.plot | ||
|
||
try: | ||
import sklearn |
Check notice
Code scanning / CodeQL
Unused import Note
@@ -1,8 +1,10 @@ | |||
import numpy as np | |||
import torch | |||
from matplotlib import patches as mpatches | |||
import matplotlib.pyplot as plt | |||
import mlcolvar.utils.plot |
Check notice
Code scanning / CodeQL
Unused import Note
@@ -0,0 +1,7 @@ | |||
import pytest |
Check notice
Code scanning / CodeQL
Unused import Note test
try: | ||
import sklearn | ||
except ImportError: | ||
print('The lasso module requires scikit-learn as additional dependency.') |
Check notice
Code scanning / CodeQL
Use of a print statement at module level Note
…olvar into interpretability
fig, axs = plt.subplots(n_feat, 1, figsize=(3, 3*n_feat)) | ||
|
||
plt.suptitle('Features distribution') | ||
init_ax = True |
Check notice
Code scanning / CodeQL
Unused local variable Note
ax.set_xlim(0, None) | ||
if n_feat != len(axs): | ||
raise ValueError(f'Number of features ({len(features)}) != number of axis ({len(axs)})') | ||
init_ax = False |
Check notice
Code scanning / CodeQL
Unused local variable Note
I have put everything into a new will merge it soon |
Description
Add sparse linear models optimized via LASSO as tools for interpreting the CVs and/or the resulting states, as done here: https://pubs.acs.org/doi/abs/10.1021/acs.jctc.2c00393.
I started from the notebook that @pietronvll and I did. We implemented both the classifier case (as done in stateinterpreter) and also the regression one. A few changes:
For both the regression and classification the signature is (almost) the same, with both returning the optimized estimator together with the list of non-zero features and their coefficients. I also did separate functions to plot the results (coefficient paths, score and number of features).
Todos
Notable points that this PR has either accomplished or will accomplish.
lasso_classification
(based onsckitlearn.LogisticRegressionCV
)lasso_regression
(based onsckitlearn.LassoCV
)Tutorials
Work in progress
Questions
utils.lasso
. However, since there is already also the sensitivity analysis contained inutils.explain
we might move all these functions into a new module calledexplain
?Status